{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ec1aae37",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
      "2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
      "2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import datasets \n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from datasets import load_dataset, load_metric\n",
    "\n",
    "from transformers import (\n",
    "    AutoModel,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForSeq2Seq,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    Seq2SeqTrainer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5fd7cb0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"t5-small\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "04530b1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the LightningDataModule\n",
    "class MyDataModule(pl.LightningDataModule):\n",
    "    def __init__(self, batch_size):\n",
    "        super().__init__()\n",
    "        self.batch_size = batch_size\n",
    "    \n",
    "    def prepare_data(self):\n",
    "        # Download and preprocess the data\n",
    "        load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
    "        load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
    "    \n",
    "    def setup(self, stage=None):\n",
    "        # Load and preprocess the data\n",
    "        train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
    "        val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
    "\n",
    "        self.train_ds = train_data.map(\n",
    "            self.preprocess_function, \n",
    "            batched=True, \n",
    "            batch_size=self.batch_size, \n",
    "            remove_columns=[\"article\", \"highlights\", \"id\"]\n",
    "        )\n",
    "\n",
    "        self.val_ds = val_data.map(\n",
    "            self.preprocess_function, \n",
    "            batched=True, \n",
    "            batch_size=self.batch_size,\n",
    "            remove_columns=[\"article\", \"highlights\", \"id\"]\n",
    "        )\n",
    "\n",
    "    def preprocess_function(self, batch):\n",
    "        inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
    "        outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
    "        batch[\"input_ids\"] = inputs.input_ids\n",
    "        batch[\"attention_mask\"] = inputs.attention_mask\n",
    "        batch[\"labels\"] = outputs.input_ids.copy()\n",
    "        return batch\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fbb699e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyLightningModule(pl.LightningModule):\n",
    "    def __init__(self, model_name, learning_rate, weight_decay, batch_size):\n",
    "        super().__init__()\n",
    "        self.model_name = model_name\n",
    "        self.learning_rate = learning_rate\n",
    "        self.weight_decay = weight_decay\n",
    "        self.batch_size = batch_size\n",
    "        \n",
    "        # Load the pre-trained model and tokenizer\n",
    "        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
    "\n",
    "        # Load the ROUGE metric\n",
    "        self.metric = load_metric(\"rouge\")\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        output = self.model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels,\n",
    "        )\n",
    "        return output.loss, output.logits\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"labels\"]\n",
    "        loss, logits = self(input_ids, attention_mask, labels)\n",
    "        self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
    "        return {'loss': loss, 'logits': logits}\n",
    "    \n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"labels\"]\n",
    "        loss, logits = self(input_ids, attention_mask, labels)\n",
    "        self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
    "        return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
    "    \n",
    "    def validation_epoch_end(self, outputs):\n",
    "        decoded_preds = []\n",
    "        decoded_labels = []\n",
    "        for output in outputs:\n",
    "            logits = output['logits']\n",
    "            labels = output['labels']\n",
    "            decoded_preds += self.tokenizer.batch_decode(logits, skip_special_tokens=True)\n",
    "            decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "        \n",
    "        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
    "        \n",
    "        self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
    "        self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
    "        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
    "        return optimizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dd63c628",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "\n",
      "  0%|                                                                                                                                               | 0/1795 [00:00<?, ?ba/s]\u001b[A\n",
      "  1%|▉                                                                                                                                    | 13/1795 [00:00<00:14, 121.44ba/s]\u001b[A\n",
      "  1%|█▉                                                                                                                                   | 26/1795 [00:00<00:15, 117.31ba/s]\u001b[A\n",
      "  2%|██▊                                                                                                                                  | 38/1795 [00:00<00:15, 114.50ba/s]\u001b[A\n",
      "  3%|███▋                                                                                                                                 | 50/1795 [00:00<00:15, 114.43ba/s]\u001b[A\n",
      "  3%|████▌                                                                                                                                | 62/1795 [00:00<00:15, 115.53ba/s]\u001b[A\n",
      "  4%|█████▍                                                                                                                               | 74/1795 [00:00<00:15, 113.50ba/s]\u001b[A\n",
      "  5%|██████▎                                                                                                                              | 86/1795 [00:00<00:15, 111.92ba/s]\u001b[A\n",
      "  5%|███████▎                                                                                                                             | 98/1795 [00:00<00:15, 111.38ba/s]\u001b[A\n",
      "  6%|████████                                                                                                                            | 110/1795 [00:00<00:15, 112.08ba/s]\u001b[A\n",
      "  7%|████████▉                                                                                                                           | 122/1795 [00:01<00:14, 113.73ba/s]\u001b[A\n",
      "  7%|█████████▊                                                                                                                          | 134/1795 [00:01<00:14, 113.43ba/s]\u001b[A\n",
      "  8%|██████████▋                                                                                                                         | 146/1795 [00:01<00:14, 111.37ba/s]\u001b[A\n",
      "  9%|███████████▌                                                                                                                        | 158/1795 [00:01<00:14, 111.32ba/s]\u001b[A\n",
      "  9%|████████████▌                                                                                                                       | 170/1795 [00:01<00:14, 110.29ba/s]\u001b[A\n",
      " 10%|█████████████▍                                                                                                                      | 182/1795 [00:01<00:14, 110.06ba/s]\u001b[A\n",
      " 11%|██████████████▎                                                                                                                     | 194/1795 [00:01<00:14, 111.06ba/s]\u001b[A\n",
      " 11%|███████████████▏                                                                                                                    | 206/1795 [00:01<00:14, 111.15ba/s]\u001b[A\n",
      " 12%|████████████████                                                                                                                    | 218/1795 [00:01<00:14, 110.27ba/s]\u001b[A\n",
      " 13%|████████████████▉                                                                                                                   | 230/1795 [00:02<00:14, 109.17ba/s]\u001b[A\n",
      " 13%|█████████████████▋                                                                                                                  | 241/1795 [00:02<00:14, 107.81ba/s]\u001b[A\n",
      " 14%|██████████████████▌                                                                                                                 | 252/1795 [00:02<00:14, 107.84ba/s]\u001b[A\n",
      " 15%|███████████████████▎                                                                                                                | 263/1795 [00:02<00:14, 107.73ba/s]\u001b[A\n",
      " 15%|████████████████████▏                                                                                                               | 274/1795 [00:02<00:14, 107.06ba/s]\u001b[A\n",
      " 16%|█████████████████████                                                                                                               | 286/1795 [00:02<00:13, 108.37ba/s]\u001b[A\n",
      " 17%|█████████████████████▊                                                                                                              | 297/1795 [00:02<00:13, 107.89ba/s]\u001b[A\n",
      " 17%|██████████████████████▋                                                                                                             | 309/1795 [00:02<00:13, 108.63ba/s]\u001b[A\n",
      " 18%|███████████████████████▌                                                                                                            | 320/1795 [00:02<00:13, 106.85ba/s]\u001b[A\n",
      " 18%|████████████████████████▎                                                                                                           | 331/1795 [00:03<00:13, 105.16ba/s]\u001b[A\n",
      " 19%|█████████████████████████▏                                                                                                          | 342/1795 [00:03<00:13, 105.20ba/s]\u001b[A\n",
      " 20%|█████████████████████████▉                                                                                                          | 353/1795 [00:03<00:13, 106.52ba/s]\u001b[A\n",
      " 20%|██████████████████████████▊                                                                                                         | 364/1795 [00:03<00:13, 106.07ba/s]\u001b[A\n",
      " 21%|███████████████████████████▌                                                                                                        | 375/1795 [00:03<00:13, 106.21ba/s]\u001b[A\n",
      " 22%|████████████████████████████▍                                                                                                       | 386/1795 [00:03<00:13, 106.57ba/s]\u001b[A\n",
      " 22%|█████████████████████████████▎                                                                                                      | 398/1795 [00:03<00:12, 108.52ba/s]\u001b[A\n",
      " 23%|██████████████████████████████                                                                                                      | 409/1795 [00:03<00:12, 108.42ba/s]\u001b[A\n",
      " 23%|██████████████████████████████▉                                                                                                     | 421/1795 [00:03<00:12, 110.30ba/s]\u001b[A\n",
      " 24%|███████████████████████████████▊                                                                                                    | 433/1795 [00:03<00:12, 108.73ba/s]\u001b[A\n",
      " 25%|████████████████████████████████▋                                                                                                   | 444/1795 [00:04<00:12, 106.43ba/s]\u001b[A\n",
      " 25%|█████████████████████████████████▍                                                                                                  | 455/1795 [00:04<00:12, 106.82ba/s]\u001b[A\n",
      " 26%|██████████████████████████████████▎                                                                                                 | 466/1795 [00:04<00:12, 105.85ba/s]\u001b[A\n",
      " 27%|███████████████████████████████████                                                                                                 | 477/1795 [00:04<00:12, 107.02ba/s]\u001b[A\n",
      " 27%|███████████████████████████████████▉                                                                                                | 488/1795 [00:04<00:12, 106.66ba/s]\u001b[A\n",
      " 28%|████████████████████████████████████▊                                                                                               | 500/1795 [00:04<00:11, 108.59ba/s]\u001b[A\n",
      " 28%|█████████████████████████████████████▌                                                                                              | 511/1795 [00:04<00:12, 106.49ba/s]\u001b[A\n",
      " 29%|██████████████████████████████████████▍                                                                                             | 523/1795 [00:04<00:11, 109.26ba/s]\u001b[A\n",
      " 30%|███████████████████████████████████████▎                                                                                            | 535/1795 [00:04<00:11, 109.78ba/s]\u001b[A\n",
      " 30%|████████████████████████████████████████▏                                                                                           | 546/1795 [00:04<00:11, 108.30ba/s]\u001b[A\n",
      " 31%|████████████████████████████████████████▉                                                                                           | 557/1795 [00:05<00:11, 107.77ba/s]\u001b[A\n",
      " 32%|█████████████████████████████████████████▊                                                                                          | 569/1795 [00:05<00:11, 108.36ba/s]\u001b[A\n",
      " 32%|██████████████████████████████████████████▋                                                                                         | 580/1795 [00:05<00:11, 107.05ba/s]\u001b[A\n",
      " 33%|███████████████████████████████████████████▌                                                                                        | 592/1795 [00:05<00:11, 108.48ba/s]\u001b[A\n",
      " 34%|████████████████████████████████████████████▎                                                                                       | 603/1795 [00:05<00:11, 108.25ba/s]\u001b[A\n",
      " 34%|█████████████████████████████████████████████▏                                                                                      | 615/1795 [00:05<00:10, 110.59ba/s]\u001b[A\n",
      " 35%|██████████████████████████████████████████████                                                                                      | 627/1795 [00:05<00:10, 111.44ba/s]\u001b[A\n",
      " 36%|██████████████████████████████████████████████▉                                                                                     | 639/1795 [00:05<00:10, 109.07ba/s]\u001b[A\n",
      " 36%|███████████████████████████████████████████████▊                                                                                    | 651/1795 [00:05<00:10, 109.77ba/s]\u001b[A\n",
      " 37%|████████████████████████████████████████████████▋                                                                                   | 662/1795 [00:06<00:10, 109.69ba/s]\u001b[A\n",
      " 37%|█████████████████████████████████████████████████▍                                                                                  | 673/1795 [00:06<00:10, 109.08ba/s]\u001b[A\n",
      " 38%|██████████████████████████████████████████████████▎                                                                                 | 685/1795 [00:06<00:10, 109.77ba/s]\u001b[A\n",
      " 39%|███████████████████████████████████████████████████▎                                                                                | 697/1795 [00:06<00:10, 109.54ba/s]\u001b[A\n",
      " 39%|████████████████████████████████████████████████████                                                                                | 708/1795 [00:06<00:09, 109.08ba/s]\u001b[A\n",
      " 40%|████████████████████████████████████████████████████▉                                                                               | 720/1795 [00:06<00:09, 110.53ba/s]\u001b[A\n",
      " 41%|█████████████████████████████████████████████████████▊                                                                              | 732/1795 [00:06<00:09, 108.30ba/s]\u001b[A\n",
      " 41%|██████████████████████████████████████████████████████▋                                                                             | 744/1795 [00:06<00:09, 110.04ba/s]\u001b[A\n",
      " 42%|███████████████████████████████████████████████████████▌                                                                            | 756/1795 [00:06<00:09, 112.10ba/s]\u001b[A\n",
      " 43%|████████████████████████████████████████████████████████▍                                                                           | 768/1795 [00:07<00:09, 111.21ba/s]\u001b[A\n",
      " 43%|█████████████████████████████████████████████████████████▎                                                                          | 780/1795 [00:07<00:09, 111.99ba/s]\u001b[A\n",
      " 44%|██████████████████████████████████████████████████████████▏                                                                         | 792/1795 [00:07<00:08, 112.21ba/s]\u001b[A\n",
      " 45%|███████████████████████████████████████████████████████████                                                                         | 804/1795 [00:07<00:09, 109.31ba/s]\u001b[A\n",
      " 46%|████████████████████████████████████████████████████████████                                                                        | 817/1795 [00:07<00:08, 113.17ba/s]\u001b[A\n",
      " 46%|████████████████████████████████████████████████████████████▉                                                                       | 829/1795 [00:07<00:08, 113.26ba/s]\u001b[A\n",
      " 47%|█████████████████████████████████████████████████████████████▊                                                                      | 841/1795 [00:07<00:08, 113.69ba/s]\u001b[A\n",
      " 48%|██████████████████████████████████████████████████████████████▋                                                                     | 853/1795 [00:07<00:08, 114.08ba/s]\u001b[A\n",
      " 48%|███████████████████████████████████████████████████████████████▌                                                                    | 865/1795 [00:07<00:08, 112.82ba/s]\u001b[A\n",
      " 49%|████████████████████████████████████████████████████████████████▍                                                                   | 877/1795 [00:07<00:08, 113.22ba/s]\u001b[A\n",
      " 50%|█████████████████████████████████████████████████████████████████▍                                                                  | 890/1795 [00:08<00:07, 115.71ba/s]\u001b[A\n",
      " 50%|██████████████████████████████████████████████████████████████████▎                                                                 | 902/1795 [00:08<00:07, 115.77ba/s]\u001b[A\n",
      " 51%|███████████████████████████████████████████████████████████████████▏                                                                | 914/1795 [00:08<00:07, 114.07ba/s]\u001b[A\n",
      " 52%|████████████████████████████████████████████████████████████████████                                                                | 926/1795 [00:08<00:07, 114.19ba/s]\u001b[A\n",
      " 52%|████████████████████████████████████████████████████████████████████▉                                                               | 938/1795 [00:08<00:07, 115.57ba/s]\u001b[A\n",
      " 53%|█████████████████████████████████████████████████████████████████████▊                                                              | 950/1795 [00:08<00:07, 115.94ba/s]\u001b[A\n",
      " 54%|██████████████████████████████████████████████████████████████████████▋                                                             | 962/1795 [00:08<00:07, 116.65ba/s]\u001b[A\n",
      " 54%|███████████████████████████████████████████████████████████████████████▋                                                            | 974/1795 [00:08<00:07, 113.94ba/s]\u001b[A\n",
      " 55%|████████████████████████████████████████████████████████████████████████▌                                                           | 986/1795 [00:08<00:07, 111.71ba/s]\u001b[A\n",
      " 56%|█████████████████████████████████████████████████████████████████████████▍                                                          | 998/1795 [00:09<00:07, 107.78ba/s]\u001b[A\n",
      " 56%|█████████████████████████████████████████████████████████████████████████▋                                                         | 1009/1795 [00:09<00:07, 105.28ba/s]\u001b[A\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 57%|██████████████████████████████████████████████████████████████████████████▌                                                        | 1021/1795 [00:09<00:07, 107.16ba/s]\u001b[A\n",
      " 57%|███████████████████████████████████████████████████████████████████████████▎                                                       | 1032/1795 [00:09<00:07, 107.83ba/s]\u001b[A\n",
      " 58%|████████████████████████████████████████████████████████████████████████████▏                                                      | 1044/1795 [00:09<00:06, 109.92ba/s]\u001b[A\n",
      " 59%|█████████████████████████████████████████████████████████████████████████████                                                      | 1056/1795 [00:09<00:06, 112.47ba/s]\u001b[A\n",
      " 59%|█████████████████████████████████████████████████████████████████████████████▉                                                     | 1068/1795 [00:09<00:06, 113.56ba/s]\u001b[A\n",
      " 60%|██████████████████████████████████████████████████████████████████████████████▊                                                    | 1080/1795 [00:09<00:06, 111.84ba/s]\u001b[A\n",
      " 61%|███████████████████████████████████████████████████████████████████████████████▋                                                   | 1092/1795 [00:09<00:06, 111.27ba/s]\u001b[A\n",
      " 62%|████████████████████████████████████████████████████████████████████████████████▌                                                  | 1104/1795 [00:10<00:06, 110.39ba/s]\u001b[A\n",
      " 62%|█████████████████████████████████████████████████████████████████████████████████▍                                                 | 1116/1795 [00:10<00:06, 111.33ba/s]\u001b[A\n",
      " 63%|██████████████████████████████████████████████████████████████████████████████████▎                                                | 1128/1795 [00:10<00:05, 111.32ba/s]\u001b[A\n",
      " 64%|███████████████████████████████████████████████████████████████████████████████████▏                                               | 1140/1795 [00:10<00:05, 112.20ba/s]\u001b[A\n",
      " 64%|████████████████████████████████████████████████████████████████████████████████████▏                                              | 1153/1795 [00:10<00:05, 115.15ba/s]\u001b[A\n",
      " 65%|█████████████████████████████████████████████████████████████████████████████████████                                              | 1165/1795 [00:10<00:05, 114.07ba/s]\u001b[A\n",
      " 66%|█████████████████████████████████████████████████████████████████████████████████████▉                                             | 1177/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
      " 66%|██████████████████████████████████████████████████████████████████████████████████████▊                                            | 1189/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
      " 67%|███████████████████████████████████████████████████████████████████████████████████████▋                                           | 1201/1795 [00:10<00:05, 112.56ba/s]\u001b[A\n",
      " 68%|████████████████████████████████████████████████████████████████████████████████████████▌                                          | 1213/1795 [00:10<00:05, 112.74ba/s]\u001b[A\n",
      " 68%|█████████████████████████████████████████████████████████████████████████████████████████▍                                         | 1225/1795 [00:11<00:05, 111.53ba/s]\u001b[A\n",
      " 69%|██████████████████████████████████████████████████████████████████████████████████████████▎                                        | 1237/1795 [00:11<00:05, 110.36ba/s]\u001b[A\n",
      " 70%|███████████████████████████████████████████████████████████████████████████████████████████▏                                       | 1249/1795 [00:11<00:04, 109.75ba/s]\u001b[A\n",
      " 70%|███████████████████████████████████████████████████████████████████████████████████████████▉                                       | 1260/1795 [00:11<00:04, 107.40ba/s]\u001b[A\n",
      " 71%|████████████████████████████████████████████████████████████████████████████████████████████▊                                      | 1271/1795 [00:11<00:04, 106.67ba/s]\u001b[A\n",
      " 71%|█████████████████████████████████████████████████████████████████████████████████████████████▌                                     | 1282/1795 [00:11<00:04, 106.95ba/s]\u001b[A\n",
      " 72%|██████████████████████████████████████████████████████████████████████████████████████████████▎                                    | 1293/1795 [00:11<00:04, 107.69ba/s]\u001b[A\n",
      " 73%|███████████████████████████████████████████████████████████████████████████████████████████████▏                                   | 1304/1795 [00:11<00:04, 107.86ba/s]\u001b[A\n",
      " 73%|███████████████████████████████████████████████████████████████████████████████████████████████▉                                   | 1315/1795 [00:11<00:04, 107.71ba/s]\u001b[A\n",
      " 74%|████████████████████████████████████████████████████████████████████████████████████████████████▊                                  | 1326/1795 [00:12<00:04, 107.71ba/s]\u001b[A\n",
      " 74%|█████████████████████████████████████████████████████████████████████████████████████████████████▌                                 | 1337/1795 [00:12<00:04, 108.29ba/s]\u001b[A\n",
      " 75%|██████████████████████████████████████████████████████████████████████████████████████████████████▍                                | 1349/1795 [00:12<00:04, 109.37ba/s]\u001b[A\n",
      " 76%|███████████████████████████████████████████████████████████████████████████████████████████████████▎                               | 1361/1795 [00:12<00:03, 110.19ba/s]\u001b[A\n",
      " 76%|████████████████████████████████████████████████████████████████████████████████████████████████████▏                              | 1373/1795 [00:12<00:03, 110.42ba/s]\u001b[A\n",
      " 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████                              | 1385/1795 [00:12<00:03, 111.32ba/s]\u001b[A\n",
      " 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉                             | 1397/1795 [00:12<00:03, 112.54ba/s]\u001b[A\n",
      " 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊                            | 1409/1795 [00:12<00:03, 112.91ba/s]\u001b[A\n",
      " 79%|███████████████████████████████████████████████████████████████████████████████████████████████████████▋                           | 1421/1795 [00:12<00:03, 111.93ba/s]\u001b[A\n",
      " 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌                          | 1433/1795 [00:12<00:03, 109.91ba/s]\u001b[A\n",
      " 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍                         | 1445/1795 [00:13<00:03, 109.29ba/s]\u001b[A\n",
      " 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 1456/1795 [00:13<00:03, 107.81ba/s]\u001b[A\n",
      " 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████                        | 1467/1795 [00:13<00:03, 107.59ba/s]\u001b[A\n",
      " 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▉                       | 1479/1795 [00:13<00:02, 107.83ba/s]\u001b[A\n",
      " 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                      | 1491/1795 [00:13<00:02, 108.92ba/s]\u001b[A\n",
      " 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                     | 1502/1795 [00:13<00:02, 108.64ba/s]\u001b[A\n",
      " 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                    | 1514/1795 [00:13<00:02, 110.24ba/s]\u001b[A\n",
      " 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                   | 1526/1795 [00:13<00:02, 111.64ba/s]\u001b[A\n",
      " 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                  | 1538/1795 [00:13<00:02, 110.08ba/s]\u001b[A\n",
      " 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████                  | 1550/1795 [00:14<00:02, 108.01ba/s]\u001b[A\n",
      " 87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                 | 1562/1795 [00:14<00:02, 109.96ba/s]\u001b[A\n",
      " 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                | 1574/1795 [00:14<00:02, 109.67ba/s]\u001b[A\n",
      " 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋               | 1585/1795 [00:14<00:01, 107.92ba/s]\u001b[A\n",
      " 89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍              | 1596/1795 [00:14<00:01, 108.38ba/s]\u001b[A\n",
      " 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍             | 1609/1795 [00:14<00:01, 112.44ba/s]\u001b[A\n",
      " 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎            | 1621/1795 [00:14<00:01, 110.29ba/s]\u001b[A\n",
      " 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏           | 1633/1795 [00:14<00:01, 110.18ba/s]\u001b[A\n",
      " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████           | 1645/1795 [00:14<00:01, 108.21ba/s]\u001b[A\n",
      " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊          | 1656/1795 [00:15<00:01, 107.62ba/s]\u001b[A\n",
      " 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋         | 1667/1795 [00:15<00:01, 106.66ba/s]\u001b[A\n",
      " 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍        | 1678/1795 [00:15<00:01, 104.97ba/s]\u001b[A\n",
      " 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 1689/1795 [00:15<00:01, 105.67ba/s]\u001b[A\n",
      " 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████       | 1700/1795 [00:15<00:00, 106.08ba/s]\u001b[A\n",
      " 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 1712/1795 [00:15<00:00, 107.07ba/s]\u001b[A\n",
      " 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊     | 1724/1795 [00:15<00:00, 108.53ba/s]\u001b[A\n",
      " 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌    | 1735/1795 [00:15<00:00, 108.05ba/s]\u001b[A\n",
      " 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍   | 1747/1795 [00:15<00:00, 110.64ba/s]\u001b[A\n",
      " 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎  | 1759/1795 [00:15<00:00, 111.38ba/s]\u001b[A\n",
      " 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1771/1795 [00:16<00:00, 110.67ba/s]\u001b[A\n",
      " 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1783/1795 [00:16<00:00, 110.52ba/s]\u001b[A\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1795/1795 [00:16<00:00, 109.98ba/s]\u001b[A\n",
      "\n",
      "  0%|                                                                                                                                                 | 0/84 [00:00<?, ?ba/s]\u001b[A\n",
      " 14%|███████████████████▎                                                                                                                   | 12/84 [00:00<00:00, 110.99ba/s]\u001b[A\n",
      " 29%|██████████████████████████████████████▌                                                                                                | 24/84 [00:00<00:00, 110.80ba/s]\u001b[A\n",
      " 43%|█████████████████████████████████████████████████████████▊                                                                             | 36/84 [00:00<00:00, 107.75ba/s]\u001b[A\n",
      " 56%|███████████████████████████████████████████████████████████████████████████▌                                                           | 47/84 [00:00<00:00, 103.83ba/s]\u001b[A\n",
      " 69%|█████████████████████████████████████████████████████████████████████████████████████████████▏                                         | 58/84 [00:00<00:00, 102.87ba/s]\u001b[A\n",
      " 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                        | 69/84 [00:00<00:00, 104.54ba/s]\u001b[A\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 106.09ba/s]\u001b[A\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
      "\n",
      "  | Name  | Type                       | Params\n",
      "-----------------------------------------------------\n",
      "0 | model | T5ForConditionalGeneration | 60.5 M\n",
      "-----------------------------------------------------\n",
      "60.5 M    Trainable params\n",
      "0         Non-trainable params\n",
      "60.5 M    Total params\n",
      "242.026   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0:   0%|                                                                                                                    | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'list' object has no attribute 'size'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m      3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(accelerator\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, devices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m      4\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdm\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m    606\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer.fit()` requires a `LightningModule`, got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    607\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 608\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    609\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m    610\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m     36\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m     37\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     40\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m     41\u001b[0m     trainer\u001b[38;5;241m.\u001b[39m_call_teardown_hook()\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m    643\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m ckpt_path \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresume_from_checkpoint\n\u001b[1;32m    644\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_set_ckpt_path(\n\u001b[1;32m    645\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m    646\u001b[0m     ckpt_path,  \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m    647\u001b[0m     model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m    648\u001b[0m     model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    649\u001b[0m )\n\u001b[0;32m--> 650\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    652\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m    653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m   1099\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mrestore_training_state()\n\u001b[1;32m   1101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mresume_end()\n\u001b[0;32m-> 1103\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1105\u001b[0m log\u001b[38;5;241m.\u001b[39mdetail(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m   1106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_teardown()\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredicting:\n\u001b[1;32m   1181\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_predict()\n\u001b[0;32m-> 1182\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195\u001b[0m, in \u001b[0;36mTrainer._run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pre_training_routine()\n\u001b[1;32m   1194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1195\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1197\u001b[0m \u001b[38;5;66;03m# enable train mode\u001b[39;00m\n\u001b[1;32m   1198\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1265\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[1;32m   1266\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1267\u001b[0m     \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1269\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m   1271\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    198\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    200\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m    201\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152\u001b[0m, in \u001b[0;36mEvaluationLoop.advance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_dataloaders \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m    151\u001b[0m     kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataloader_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[0;32m--> 152\u001b[0m dl_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdl_max_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    154\u001b[0m \u001b[38;5;66;03m# store batch level output per dataloader\u001b[39;00m\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs\u001b[38;5;241m.\u001b[39mappend(dl_outputs)\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    198\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    200\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m    201\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137\u001b[0m, in \u001b[0;36mEvaluationEpochLoop.advance\u001b[0;34m(self, data_fetcher, dl_max_batches, kwargs)\u001b[0m\n\u001b[1;32m    134\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n\u001b[1;32m    136\u001b[0m \u001b[38;5;66;03m# lightning module methods\u001b[39;00m\n\u001b[0;32m--> 137\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    138\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluation_step_end(output)\n\u001b[1;32m    140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234\u001b[0m, in \u001b[0;36mEvaluationEpochLoop._evaluation_step\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m    223\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"The evaluation step (validation_step or test_step depending on the trainer's state).\u001b[39;00m\n\u001b[1;32m    224\u001b[0m \n\u001b[1;32m    225\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    231\u001b[0m \u001b[38;5;124;03m    the outputs of the step\u001b[39;00m\n\u001b[1;32m    232\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    233\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 234\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    236\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485\u001b[0m, in \u001b[0;36mTrainer._call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1482\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m   1484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 1485\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1487\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m   1488\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m    389\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m     34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m     35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m     38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     17\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     18\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     19\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     20\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     21\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m   1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   1623\u001b[0m     \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m     encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1625\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1626\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1627\u001b[0m \u001b[43m        \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1628\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1629\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1630\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1631\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1632\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m   1634\u001b[0m     encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m   1635\u001b[0m         last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m   1636\u001b[0m         hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1637\u001b[0m         attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1638\u001b[0m     )\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    940\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    941\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    942\u001b[0m     )\n\u001b[1;32m    943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m     input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m    945\u001b[0m     input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m    946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
     ]
    }
   ],
   "source": [
    "torch.set_float32_matmul_precision(\"medium\")\n",
    "model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
    "trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
    "dm = MyDataModule(batch_size=16)\n",
    "trainer.fit(model, datamodule=dm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1395d5d2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80a2efab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
